diff --git a/GitHubClient.py b/GitHubClient.py index f38db91..15db432 100644 --- a/GitHubClient.py +++ b/GitHubClient.py @@ -17,7 +17,7 @@ class GitHubClient(object): self.before = os.getenv('INPUT_BEFORE') self.sha = os.getenv('INPUT_SHA') self.commits = json.loads(os.getenv('INPUT_COMMITS')) or [] - self.diff_url = os.getenv('INPUT_DIFF_URL') + self.__init_diff_url__() self.token = os.getenv('INPUT_TOKEN') self.issues_url = f'{self.repos_url}{self.repo}/issues' self.milestones_url = f'{self.repos_url}{self.repo}/milestones' @@ -45,6 +45,20 @@ class GitHubClient(object): # Populate milestones so we can perform a lookup if one is specified. self._get_milestones() + def __init_diff_url__(self): + manual_commit_ref = os.getenv('MANUAL_COMMIT_REF') + manual_base_ref = os.getenv('MANUAL_BASE_REF') + if manual_commit_ref: + self.sha = manual_commit_ref + if manual_commit_ref and manual_base_ref: + print(f'Manually comparing {manual_base_ref}...{manual_commit_ref}') + self.diff_url = f'{self.repos_url}{self.repo}/compare/{manual_base_ref}...{manual_commit_ref}' + elif manual_commit_ref: + print(f'Manual checking {manual_commit_ref}') + self.diff_url = f'{self.repos_url}{self.repo}/commits/{manual_commit_ref}' + else: + self.diff_url = os.getenv('INPUT_DIFF_URL') + def get_last_diff(self): """Get the last diff.""" if self.diff_url: @@ -56,10 +70,12 @@ class GitHubClient(object): elif len(self.commits) == 1: # There is only one commit. diff_url = f'{self.repos_url}{self.repo}/commits/{self.sha}' - else: + elif len(self.commits) > 1: # There are several commits: compare with the oldest one. oldest = sorted(self.commits, key=self._get_timestamp)[0]['id'] diff_url = f'{self.repos_url}{self.repo}/compare/{oldest}...{self.sha}' + else: + return None diff_headers = { 'Accept': 'application/vnd.github.v3.diff', diff --git a/LocalClient.py b/LocalClient.py index 9d0b606..b15a495 100644 --- a/LocalClient.py +++ b/LocalClient.py @@ -1,13 +1,36 @@ import subprocess +import os class LocalClient(object): def __init__(self): self.diff_url = None self.commits = ['placeholder'] # content doesn't matter, just length self.insert_issue_urls = False + self.__set_diff_refs__() + + def __set_diff_refs__(self): + # set the target of the comparison to user-specified value, if + # provided, falling back to HEAD + manual_commit_ref = os.getenv('MANUAL_COMMIT_REF') + if manual_commit_ref: + self.sha = manual_commit_ref + else: + self.sha = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8').strip() + # set the soruce of the comparison to user-specified value, if + # provided, falling back to commit immediately before the target + manual_base_ref = os.getenv('MANUAL_BASE_REF') + if manual_base_ref: + self.base_ref = manual_base_ref + else: + self.base_ref = subprocess.run(['git', 'rev-parse', f'{self.sha}^'], stdout=subprocess.PIPE).stdout.decode('utf-8').strip() + # print feedback to the user + if manual_commit_ref and manual_base_ref: + print(f'Manually comparing {manual_base_ref}...{manual_commit_ref}') + elif manual_commit_ref: + print(f'Manual checking {manual_commit_ref}') def get_last_diff(self): - return subprocess.run(['git', 'diff', 'HEAD^..HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8') + return subprocess.run(['git', 'diff', f'{self.base_ref}..{self.sha}'], stdout=subprocess.PIPE).stdout.decode('latin-1') def create_issue(self, issue): return [201, None] diff --git a/main.py b/main.py index a16f753..63552d0 100644 --- a/main.py +++ b/main.py @@ -25,23 +25,12 @@ if __name__ == "__main__": # if needed, fall back to using a local client for testing client = client or LocalClient() - # Check to see if the workflow has been run manually. - # If so, adjust the client SHA and diff URL to use the manually supplied inputs. - manual_commit_ref = os.getenv('MANUAL_COMMIT_REF') - manual_base_ref = os.getenv('MANUAL_BASE_REF') - if manual_commit_ref: - client.sha = manual_commit_ref - if manual_commit_ref and manual_base_ref: - print(f'Manually comparing {manual_base_ref}...{manual_commit_ref}') - client.diff_url = f'{client.repos_url}{client.repo}/compare/{manual_base_ref}...{manual_commit_ref}' - elif manual_commit_ref: - print(f'Manual checking {manual_commit_ref}') - client.diff_url = f'{client.repos_url}{client.repo}/commits/{manual_commit_ref}' - if client.diff_url or len(client.commits) != 0: - # Get the diff from the last pushed commit. - last_diff = StringIO(client.get_last_diff()) + # Get the diff from the last pushed commit. + last_diff = client.get_last_diff() + + if last_diff: # Parse the diff for TODOs and create an Issue object for each. - raw_issues = TodoParser().parse(last_diff) + raw_issues = TodoParser().parse(StringIO(last_diff)) # This is a simple, non-perfect check to filter out any TODOs that have just been moved. # It looks for items that appear in the diff as both an addition and deletion. # It is based on the assumption that TODOs will not have identical titles in identical files.